import jax.numpy as jnp
import chex
from rl import network
from quantity import walls, true_q


def get_true_error(agent_state):
  """Compute the squared error of the greedy Q-values against ``true_q``.

  Every grid cell is fed to the Q-network as a one-hot observation. For each
  cell, the maximum of the network output over its last axis is compared with
  the module-level ``true_q`` table, and the difference is weighted by
  ``walls``.

  Args:
    agent_state: Object exposing ``q_params``, the parameters that are passed
      to ``network.apply``.

  Returns:
    A dict with the single key ``'true_error'``, holding the mean of the
    squared ``walls``-weighted difference taken over all grid cells.
  """
  # Derive the grid from the reference table instead of hard-coding 19x19, so
  # the metric follows whatever grid ``true_q`` describes.
  grid_shape = jnp.shape(true_q)
  num_cells = int(jnp.size(true_q))
  # Identity matrix: row i is the one-hot observation for flattened cell i.
  obs = jnp.eye(num_cells)
  q_values = network.apply(agent_state.q_params, obs)
  max_q = jnp.max(q_values, axis=-1)
  # Row-major reshape maps flattened cell i back onto its grid position.
  max_q = jnp.reshape(max_q, grid_shape)
  chex.assert_equal_shape([walls, true_q, max_q])
  # Weight each cell's difference by ``walls`` (zero wherever walls is 0).
  # NOTE(review): the mean below divides by *all* cells, including those
  # zeroed by the mask — confirm this normalisation is intended.
  diff_q = (max_q - true_q) * walls
  return {'true_error': jnp.mean(diff_q**2)}
